{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "8347485f",
   "metadata": {},
   "source": [
    "# How to Build a Reverse Video Search Engine\n",
    "\n",
    "This notebook illustrates how to build a reverse-video-search engine from scratch using [Milvus](https://milvus.io/) and [Towhee](https://towhee.io/).\n",
    "\n",
    "**What is Reverse Video Search?**\n",
    "\n",
    "Reverse video search is similar like [reverse image search](https://en.wikipedia.org/wiki/Reverse_image_search). In simple words, it takes a video as input to search for similar videos. As we know that video-related tasks are harder to tackle, video models normally do not achieve as high scores as other types of models. However, there are increasing demands in AI applications in video. Reverse video search can effectively discover related videos and improve other applications.\n",
    "\n",
    "\n",
    "**What are Milvus & Towhee?**\n",
    "\n",
    "- Milvus is the most advanced open-source vector database built for AI applications and supports nearest neighbor embedding search across tens of millions of entries.\n",
    "- Towhee is a framework that provides ETL for unstructured data using SoTA machine learning models.\n",
    "\n",
    "We will go through the procedure of building a reverse-video-search engine and evaluate its performance."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "569571ec",
   "metadata": {},
   "source": [
    "## Preparation\n",
    "\n",
    "### Install packages\n",
    "\n",
    "Make sure you have installed required python packages:\n",
    "\n",
    "| package |\n",
    "| -- |\n",
    "| towhee |\n",
    "| towhee.models |\n",
    "| pillow |\n",
    "| ipython |\n",
    "| gradio |"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "d2d8e3e7",
   "metadata": {},
   "outputs": [],
   "source": [
    "! python -m pip install -q towhee towhee.models pillow ipython gradio"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "11ef6b1a",
   "metadata": {},
   "source": [
    "### Prepare data\n",
    "\n",
    "This tutorial will use a small data extracted from [Kinetics400](https://www.deepmind.com/open-source/kinetics). You can download the subset from [Github](https://github.com/towhee-io/examples/releases/download/data/reverse_video_search.zip). \n",
    "\n",
    "The data is organized as follows:\n",
    "- **train:** candidate videos, 20 classes, 10 videos per class (200 in total)\n",
    "- **test:** query videos, same 20 classes as train data, 1 video per class (20 in total)\n",
    "- **reverse_video_search.csv:** a csv file containing an ***id***, ***path***, and ***label*** for each video in train data\n",
    "\n",
    "Let's take a quick look:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "54568b1a",
   "metadata": {},
   "outputs": [],
   "source": [
    "! curl -L https://github.com/towhee-io/examples/releases/download/data/reverse_video_search.zip -O\n",
    "! unzip -q -o reverse_video_search.zip"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "490d1379",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>id</th>\n",
       "      <th>path</th>\n",
       "      <th>label</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0</td>\n",
       "      <td>./train/country_line_dancing/bTbC3w_NIvM.mp4</td>\n",
       "      <td>country_line_dancing</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>1</td>\n",
       "      <td>./train/country_line_dancing/n2dWtEmNn5c.mp4</td>\n",
       "      <td>country_line_dancing</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>2</td>\n",
       "      <td>./train/country_line_dancing/zta-Iv-xK7I.mp4</td>\n",
       "      <td>country_line_dancing</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   id                                          path                 label\n",
       "0   0  ./train/country_line_dancing/bTbC3w_NIvM.mp4  country_line_dancing\n",
       "1   1  ./train/country_line_dancing/n2dWtEmNn5c.mp4  country_line_dancing\n",
       "2   2  ./train/country_line_dancing/zta-Iv-xK7I.mp4  country_line_dancing"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import pandas as pd\n",
    "\n",
    "df = pd.read_csv('./reverse_video_search.csv')\n",
    "df.head(3)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2171f2e7",
   "metadata": {},
   "source": [
    "For later steps to easier get videos & measure results, we build some helpful functions in advance:\n",
    "- **ground_truth:** get ground-truth video ids for the query video by its path"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "dd1b0ef0",
   "metadata": {},
   "outputs": [],
   "source": [
    "id_video = df.set_index('id')['path'].to_dict()\n",
    "label_ids = {}\n",
    "for label in set(df['label']):\n",
    "    label_ids[label] = list(df[df['label']==label].id)\n",
    "    \n",
    "\n",
    "def ground_truth(path):\n",
    "    label = path.split('/')[-2]\n",
    "    return label_ids[label]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b3e98f62",
   "metadata": {},
   "source": [
    "### Start Milvus\n",
    "\n",
    "Before getting started with the engine, we also need to get ready with Milvus. Please make sure that you have started a [Milvus service](https://milvus.io/docs/install_standalone-docker.md). This notebook uses [milvus 2.2.10](https://milvus.io/docs/v2.2.x/install_standalone-docker.md) and [pymilvus 2.2.11](https://milvus.io/docs/release_notes.md#2210)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a422f673",
   "metadata": {},
   "outputs": [],
   "source": [
    "! python -m pip install -q pymilvus==2.2.11"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "60bf1321",
   "metadata": {},
   "source": [
    "Here we prepare a function to work with a Milvus collection with the following parameters:\n",
    "- [L2 distance metric](https://milvus.io/docs/metric.md#Euclidean-distance-L2)\n",
    "- [IVF_FLAT index](https://milvus.io/docs/index.md#IVF_FLAT)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "f4fbffa1",
   "metadata": {},
   "outputs": [],
   "source": [
    "from pymilvus import connections, FieldSchema, CollectionSchema, DataType, Collection, utility\n",
    "\n",
    "connections.connect(host='localhost', port='19530')\n",
    "\n",
    "def create_milvus_collection(collection_name, dim):    \n",
    "    if utility.has_collection(collection_name):\n",
    "        utility.drop_collection(collection_name)\n",
    "    \n",
    "    fields = [\n",
    "    FieldSchema(name='id', dtype=DataType.INT64, descrition='ids', is_primary=True, auto_id=False),\n",
    "    FieldSchema(name='embedding', dtype=DataType.FLOAT_VECTOR, descrition='embedding vectors', dim=dim)\n",
    "    ]\n",
    "    schema = CollectionSchema(fields=fields, description='reverse video search')\n",
    "    collection = Collection(name=collection_name, schema=schema)\n",
    "\n",
    "    # create IVF_FLAT index for collection.\n",
    "    index_params = {\n",
    "        'metric_type':'L2',\n",
    "        'index_type':\"IVF_FLAT\",\n",
    "        'params':{\"nlist\": 400}\n",
    "    }\n",
    "    collection.create_index(field_name=\"embedding\", index_params=index_params)\n",
    "    return collection\n",
    "\n",
    "collection = create_milvus_collection('x3d_m', 2048)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a30aa33c",
   "metadata": {},
   "source": [
    "## Build Engine\n",
    "\n",
    "Now we are ready to build a reverse-video-search engine. The basic idea behind reverse video search is to represent each video with an embedding and then perform similarity search by comparing vector distances.\n",
    "\n",
    "As mentioned at the beginning, we use deep learning networks provided by Towhee to extract features and generate embeddings. Milvus is used for vector storage and similarity search.\n",
    "\n",
    "<img src='reverse_video_search.png' alt='reverse_video_search_engine' width=700px/>"
   ]
  },
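  {
   "cell_type": "markdown",
   "id": "3c1e7a90",
   "metadata": {},
   "source": [
    "To make the idea concrete, here is a minimal sketch of similarity search by brute force with NumPy. The 4-dimensional vectors and the names `toy_candidates` and `toy_query` are made up for illustration; the engine built in this notebook uses model embeddings and a Milvus index instead of comparing against every vector."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5d2b8f14",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Hypothetical toy embeddings: one row per candidate video (not real model output).\n",
    "toy_candidates = np.array([\n",
    "    [0.9, 0.1, 0.0, 0.0],\n",
    "    [0.0, 1.0, 0.0, 0.0],\n",
    "    [0.8, 0.2, 0.1, 0.0],\n",
    "    [0.0, 0.0, 1.0, 0.0],\n",
    "])\n",
    "toy_query = np.array([1.0, 0.0, 0.0, 0.0])\n",
    "\n",
    "# L2 distance between the query and every candidate.\n",
    "toy_distances = np.linalg.norm(toy_candidates - toy_query, axis=1)\n",
    "\n",
    "# Indices of the 2 nearest candidates, closest first.\n",
    "toy_top_k = np.argsort(toy_distances)[:2]\n",
    "print(toy_top_k, toy_distances[toy_top_k])"
   ]
  },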
  {
   "cell_type": "markdown",
   "id": "61433908",
   "metadata": {},
   "source": [
    "### Load Video Embeddings into Milvus\n",
    "\n",
    "We first generate embeddings for videos with [X3D model](https://arxiv.org/abs/2004.04730) and then insert video embeddings into Milvus. Towhee provides a [method-chaining style API](https://towhee.readthedocs.io/en/main/index.html) so that users can assemble a data processing pipeline with operators. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "7b2619b2",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using cache found in /Users/chizzy/.cache/torch/hub/facebookresearch_pytorchvideo_main\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Total number of inserted data is 200.\n"
     ]
    }
   ],
   "source": [
    "from towhee import pipe, ops\n",
    "from towhee.datacollection import DataCollection\n",
    "\n",
    "def read_csv(csv_file):\n",
    "    import csv\n",
    "    with open(csv_file, 'r', encoding='utf-8-sig') as f:\n",
    "        data = csv.DictReader(f)\n",
    "        for line in data:\n",
    "            yield line['id'], line['path'], line['label']\n",
    "            \n",
    "\n",
    "insert_pipe = (\n",
    "    pipe.input('csv_path')\n",
    "        .flat_map('csv_path', ('id', 'path', 'label'), read_csv)\n",
    "        .map('id', 'id', lambda x: int(x))\n",
    "        .map('path', 'frames', ops.video_decode.ffmpeg(sample_type='uniform_temporal_subsample', args={'num_samples': 16}))\n",
    "        .map('frames', ('labels', 'scores', 'features'), ops.action_classification.pytorchvideo(model_name='x3d_m', skip_preprocess=True))\n",
    "        .map(('id', 'features'), 'insert_res', ops.ann_insert.milvus_client(host='127.0.0.1', port='19530', collection_name='x3d_m'))\n",
    "        .output()\n",
    ")\n",
    "\n",
    "insert_pipe('reverse_video_search.csv')\n",
    "print('Total number of inserted data is {}.'.format(collection.num_entities))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2b987a6e",
   "metadata": {},
   "source": [
    "#### Pipeline Explanation\n",
    "\n",
    "Here are some details for each line of the assemble pipeline:\n",
    "\n",
    "- `flat_map('csv_path', ('id', 'path', 'label'), read_csv)`: read tabular data from csv file\n",
    "\n",
    "- `map('id', 'id', lambda x: int(x))`: for each row from the data, convert data type of the column id to int\n",
    "\n",
    "- `map('path', 'frames', ops.video_decode.ffmpeg(sample_type='uniform_temporal_subsample', args={'num_samples': 16}))`: an embeded Towhee operator reading video as frames with specified sample method and number of samples. [learn more](https://towhee.io/video-decode/ffmpeg)\n",
    "\n",
    "- `map('frames', ('labels', 'scores', 'features'), ops.action_classification.pytorchvideo(model_name='x3d_m', skip_preprocess=True))`: an embeded Towhee operator applying specified model to video frames, which can be used to generate video embedding. [learn more](https://towhee.io/action-classification/pytorchvideo)\n",
    "\n",
    "- `map(('id', 'features'), 'insert_res', ops.ann_insert.milvus_client(host='127.0.0.1', port='19530', collection_name='x3d_m'))`: insert video embedding into Milvus collection"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fcd97399",
   "metadata": {},
   "source": [
    "### Query Similar Videos from Milvus\n",
    "\n",
    "Now all embeddings of candidate videos have been inserted into Milvus collection, we can query embeddings across the collection for nearest neighbors.\n",
    "\n",
    "To get query embeddings, we should go through same pre-insert steps for each input video. Because Milvus returns video ids and vector distances, we use the `id_video` dictionary to get corresponding video paths based on ids."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "f0e31a67",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using cache found in /Users/chizzy/.cache/torch/hub/facebookresearch_pytorchvideo_main\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<table style=\"border-collapse: collapse;\"><tr><th style=\"text-align: center; font-size: 130%; border: none;\">path</th> <th style=\"text-align: center; font-size: 130%; border: none;\">candidates</th></tr> <tr><td style=\"text-align: center; vertical-align: center; border-right: solid 1px #D3D3D3; border-left: solid 1px #D3D3D3; \">./test/eating_carrots/ty4UQlowp0c.mp4</td> <td style=\"text-align: center; vertical-align: center; border-right: solid 1px #D3D3D3; border-left: solid 1px #D3D3D3; \"><br>./train/eating_carrots/V7DUq0JJneY.mp4</br> <br>./train/eating_carrots/bTCznQiu0hc.mp4</br> <br>./train/eating_carrots/Ou1w86qEr58.mp4</br> <br>./train/eating_carrots/Ka6z9NtiVMQ.mp4</br> <br>./train/eating_carrots/9OZhQqMhX50.mp4</br> <br>./train/eating_carrots/WkwzsrDd-Ws.mp4</br> <br>./train/eating_hotdog/hyt69aadDDU.mp4</br> <br>./train/eating_carrots/s4VofFq0NrI.mp4</br> <br>./train/eating_hotdog/WPc63G4sQU4.mp4</br> <br>./train/eating_carrots/sLQnKz6qv08.mp4</br></td></tr></table>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "collection.load()\n",
    "\n",
    "query_path = './test/eating_carrots/ty4UQlowp0c.mp4'\n",
    "\n",
    "query_pipe = (\n",
    "    pipe.input('path')\n",
    "        .map('path', 'frames', ops.video_decode.ffmpeg(sample_type='uniform_temporal_subsample', args={'num_samples': 16}))\n",
    "        .map('frames', ('labels', 'scores', 'features'), ops.action_classification.pytorchvideo(model_name='x3d_m', skip_preprocess=True))\n",
    "        .map('features', 'result', ops.ann_search.milvus_client(host='127.0.0.1', port='19530', collection_name='x3d_m', limit=10))  \n",
    "        .map('result', 'candidates', lambda x: [id_video[i[0]] for i in x])\n",
    "        .output('path', 'candidates')\n",
    ")\n",
    "\n",
    "res = DataCollection(query_pipe(query_path))\n",
    "res.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8cd96c30",
   "metadata": {},
   "source": [
    "To display in the notebook, we convert videos to gifs. The code below first loads each video from its path and then gets full video frames with the embeded Towhee operator `.video_decode.ffmpeg()`. Finally converted gifs are saved under the directory *tmp_dir*. The section below is just help to show a search example."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "b58bfe4d",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "Query video \"eating_carrots\": <br/><img src=\"./tmp/ty4UQlowp0c.gif\"> <br/>Top 3 search results: <br/><img src=\"./tmp/V7DUq0JJneY.gif\" style=\"display:inline;margin:1px\"/><img src=\"./tmp/bTCznQiu0hc.gif\" style=\"display:inline;margin:1px\"/><img src=\"./tmp/Ou1w86qEr58.gif\" style=\"display:inline;margin:1px\"/>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import os\n",
    "from IPython import display\n",
    "from PIL import Image\n",
    "\n",
    "tmp_dir = './tmp'\n",
    "os.makedirs(tmp_dir, exist_ok=True)\n",
    "\n",
    "def video_to_gif(video_path):\n",
    "    gif_path = os.path.join(tmp_dir, video_path.split('/')[-1][:-4] + '.gif')\n",
    "    p = (\n",
    "        pipe.input('path')\n",
    "            .map('path', 'frames', ops.video_decode.ffmpeg(sample_type='uniform_temporal_subsample', args={'num_samples': 16}))\n",
    "            .output('frames')\n",
    "    )\n",
    "    frames = p(video_path).get()[0]\n",
    "    imgs = [Image.fromarray(frame) for frame in frames]\n",
    "    imgs[0].save(fp=gif_path, format='GIF', append_images=imgs[1:], save_all=True, loop=0)\n",
    "    return gif_path\n",
    "\n",
    "html = 'Query video \"{}\": <br/>'.format(query_path.split('/')[-2])\n",
    "query_gif = video_to_gif(query_path)\n",
    "html_line = '<img src=\"{}\"> <br/>'.format(query_gif)\n",
    "html +=  html_line\n",
    "html += 'Top 3 search results: <br/>'\n",
    "\n",
    "for path in res[0]['candidates'][:3]:\n",
    "    gif_path = video_to_gif(path)\n",
    "    html_line = '<img src=\"{}\" style=\"display:inline;margin:1px\"/>'.format(gif_path)\n",
    "    html +=  html_line\n",
    "display.HTML(html)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "92cc4ef9",
   "metadata": {},
   "source": [
    "### Evaluation\n",
    "\n",
    "We have just built a reverse video search engine. But how's its performance? We can evaluate the search engine against the ground truths.\n",
    "\n",
    "In this section, we'll measure the performance with 2 metrics - mHR and mAP:\n",
    "\n",
    "- **mHR (recall@K):**\n",
    "    - Mean Hit Ratio describes how many actual relevant results are returned out of all ground truths.\n",
    "    - Since Milvus return results with topK, we can also call this metric *recall@K*, where K is the count of searched results. When returned results are as many as ground truths, the hit ratio is equivalent to accuracy and we can take it as *accuracy@K* as well.\n",
    "    - For example, there are 100 archery videos in the collection. Then querying the engine with another archery video returns 70 archery videos out of 80 results. In this case, the number of ground truths is 100 and hitted (correct) results are 70. So the hit ratio is 70/100.\n",
    "\n",
    "- **mAP:**\n",
    "    - Average precision describes whether all of the relevant results are ranked higher than irrelevant results."
   ]
  },
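  {
   "cell_type": "markdown",
   "id": "7e4a09c6",
   "metadata": {},
   "source": [
    "As a quick worked example, the sketch below computes both per-query quantities for a single made-up query. The ids in `toy_truth` and `toy_returned` are hypothetical, and the precision average follows the same per-rank definition as the evaluation code that follows, not a separate library implementation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9b3f6d21",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hypothetical toy data: ids of the relevant videos and the ranked ids returned by a search.\n",
    "toy_truth = [0, 1, 2, 3]\n",
    "toy_returned = [0, 9, 1, 8]\n",
    "\n",
    "# Hit ratio: share of the ground truths that appear in the returned results.\n",
    "toy_hits = len(set(toy_truth) & set(toy_returned))\n",
    "toy_hit_ratio = toy_hits / len(toy_truth)\n",
    "\n",
    "# Precision after each rank, averaged over all ranks of the returned list.\n",
    "toy_precisions = []\n",
    "toy_hit_count = 0\n",
    "for toy_rank, toy_id in enumerate(toy_returned, start=1):\n",
    "    if toy_id in toy_truth:\n",
    "        toy_hit_count += 1\n",
    "    toy_precisions.append(toy_hit_count / toy_rank)\n",
    "toy_average_precision = sum(toy_precisions) / len(toy_precisions)\n",
    "\n",
    "# Precisions per rank are 1/1, 1/2, 2/3, 2/4.\n",
    "print(toy_hit_ratio, round(toy_average_precision, 4))"
   ]
  },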
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "6d468a31",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using cache found in /Users/chizzy/.cache/torch/hub/facebookresearch_pytorchvideo_main\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<table style=\"border-collapse: collapse;\"><tr><th style=\"text-align: center; font-size: 130%; border: none;\">mHR</th> <th style=\"text-align: center; font-size: 130%; border: none;\">mAP</th></tr> <tr><td style=\"text-align: center; vertical-align: center; border-right: solid 1px #D3D3D3; border-left: solid 1px #D3D3D3; \">0.57</td> <td style=\"text-align: center; vertical-align: center; border-right: solid 1px #D3D3D3; border-left: solid 1px #D3D3D3; \">0.6604821428571429</td></tr></table>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "import glob\n",
    "\n",
    "def mean_hit_ratio(actual, predicted):\n",
    "    ratios = []\n",
    "    for act, pre in zip(actual, predicted):\n",
    "        hit_num = len(set(act) & set(pre))\n",
    "        ratios.append(hit_num / len(act))\n",
    "    return sum(ratios) / len(ratios)\n",
    "\n",
    "def mean_average_precision(actual, predicted):\n",
    "    aps = []\n",
    "    for act, pre in zip(actual, predicted):\n",
    "        precisions = []\n",
    "        hit = 0\n",
    "        for idx, i in enumerate(pre):\n",
    "            if i in act:\n",
    "                hit += 1\n",
    "            precisions.append(hit / (idx + 1))\n",
    "        aps.append(sum(precisions) / len(precisions))\n",
    "    \n",
    "    return sum(aps) / len(aps)\n",
    "\n",
    "eval_pipe = (\n",
    "    pipe.input('path')\n",
    "        .flat_map('path', 'path', lambda x: glob.glob(x))\n",
    "        .map('path', 'frames', ops.video_decode.ffmpeg(sample_type='uniform_temporal_subsample', args={'num_samples': 16}))\n",
    "        .map('frames', ('labels', 'scores', 'features'), ops.action_classification.pytorchvideo(model_name='x3d_m', skip_preprocess=True))\n",
    "        .map('features', 'result', ops.ann_search.milvus_client(host='127.0.0.1', port='19530', collection_name='x3d_m', limit=10))  \n",
    "        .map('result', 'predict', lambda x: [i[0] for i in x])\n",
    "        .map('path', 'ground_truth', ground_truth)\n",
    "        .window_all(('ground_truth', 'predict'), 'mHR', mean_hit_ratio)\n",
    "        .window_all(('ground_truth', 'predict'), 'mAP', mean_average_precision)\n",
    "        .output('mHR', 'mAP')\n",
    ")\n",
    "\n",
    "res = DataCollection(eval_pipe('./test/*/*.mp4'))\n",
    "res.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2a80f350",
   "metadata": {},
   "source": [
    "## Optimization\n",
    "\n",
    "We can see from above evaluation report, the current performance is not satisfactory. What can we do to improve the search engine? Of course we can fine-tune deep learning network with our own train data. Using more types of embeddings or filters by video tags/description/captions and audio can definitely enhance the search engine as well. But in this tutorial, I will just recommend some very simple options to make improvements.\n",
    "\n",
    "### Normalize embeddings\n",
    "\n",
    "A quick optimization is normalizing all embeddings. Then the L2 distance will be equivalent to cosine similarity, which measures the similarity between two vectors using the angle between them, which ignores the magnitude of the vectors. We use the `ops.towhee.np_normalize` provided by Towhee to simply normalize all embeddings."
   ]
  },
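  {
   "cell_type": "markdown",
   "id": "c48d17ea",
   "metadata": {},
   "source": [
    "The relation between L2 distance and cosine similarity for normalized vectors can be checked numerically. This is a minimal sketch with random stand-in vectors (`vec_a` and `vec_b` are hypothetical, not real video embeddings), and it normalizes with plain NumPy instead of the Towhee operator."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e06a52bf",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "rng = np.random.default_rng(0)\n",
    "# Two hypothetical embeddings (random stand-ins, not real model output).\n",
    "vec_a = rng.normal(size=8)\n",
    "vec_b = rng.normal(size=8)\n",
    "\n",
    "# Scale each vector to unit length.\n",
    "unit_a = vec_a / np.linalg.norm(vec_a)\n",
    "unit_b = vec_b / np.linalg.norm(vec_b)\n",
    "\n",
    "cosine = float(unit_a @ unit_b)\n",
    "squared_l2 = float(np.sum((unit_a - unit_b) ** 2))\n",
    "\n",
    "# For unit vectors, the squared L2 distance equals 2 - 2 * cosine similarity,\n",
    "# so the two printed values should agree up to floating-point error.\n",
    "print(squared_l2, 2 - 2 * cosine)"
   ]
  },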
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "04489af1",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using cache found in /Users/chizzy/.cache/torch/hub/facebookresearch_pytorchvideo_main\n",
      "Using cache found in /Users/chizzy/.cache/torch/hub/facebookresearch_pytorchvideo_main\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<table style=\"border-collapse: collapse;\"><tr><th style=\"text-align: center; font-size: 130%; border: none;\">mHR</th> <th style=\"text-align: center; font-size: 130%; border: none;\">mAP</th></tr> <tr><td style=\"text-align: center; vertical-align: center; border-right: solid 1px #D3D3D3; border-left: solid 1px #D3D3D3; \">0.66</td> <td style=\"text-align: center; vertical-align: center; border-right: solid 1px #D3D3D3; border-left: solid 1px #D3D3D3; \">0.7379801587301585</td></tr></table>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "collection = create_milvus_collection('x3d_m_norm', 2048)\n",
    "\n",
    "insert_pipe = (\n",
    "    pipe.input('csv_path')\n",
    "        .flat_map('csv_path', ('id', 'path', 'label'), read_csv)\n",
    "        .map('id', 'id', lambda x: int(x))\n",
    "        .map('path', 'frames', ops.video_decode.ffmpeg(sample_type='uniform_temporal_subsample', args={'num_samples': 16}))\n",
    "        .map('frames', ('labels', 'scores', 'features'), ops.action_classification.pytorchvideo(model_name='x3d_m', skip_preprocess=True))\n",
    "        .map('features', 'features', ops.towhee.np_normalize())\n",
    "        .map(('id', 'features'), 'insert_res', ops.ann_insert.milvus_client(host='127.0.0.1', port='19530', collection_name='x3d_m_norm'))\n",
    "        .output()\n",
    ")\n",
    "\n",
    "insert_pipe('reverse_video_search.csv')\n",
    "\n",
    "collection.load()\n",
    "eval_pipe = (\n",
    "    pipe.input('path')\n",
    "        .flat_map('path', 'path', lambda x: glob.glob(x))\n",
    "        .map('path', 'frames', ops.video_decode.ffmpeg(sample_type='uniform_temporal_subsample', args={'num_samples': 16}))\n",
    "        .map('frames', ('labels', 'scores', 'features'), ops.action_classification.pytorchvideo(model_name='x3d_m', skip_preprocess=True))\n",
    "        .map('features', 'features', ops.towhee.np_normalize())\n",
    "        .map('features', 'result', ops.ann_search.milvus_client(host='127.0.0.1', port='19530', collection_name='x3d_m_norm', limit=10))  \n",
    "        .map('result', 'predict', lambda x: [i[0] for i in x])\n",
    "        .map('path', 'ground_truth', ground_truth)\n",
    "        .window_all(('ground_truth', 'predict'), 'mHR', mean_hit_ratio)\n",
    "        .window_all(('ground_truth', 'predict'), 'mAP', mean_average_precision)\n",
    "        .output('mHR', 'mAP')\n",
    ")\n",
    "\n",
    "res = DataCollection(eval_pipe('./test/*/*.mp4'))\n",
    "res.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "053fe927",
   "metadata": {},
   "source": [
    "With vector normalization, we have increased the mHR to 0.66 and mAP to about 0.74, which look better now."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "924813d7",
   "metadata": {},
   "source": [
    "### Change model\n",
    "\n",
    "There are more video models using different networks. Normally a more complicated or larger model will show better results while cost more. You can always try more models to tradeoff among accuracy, latency, and resource usage. Here I show the performance for the reverse video search engine using a SOTA model with [multiscale vision transformer](https://arxiv.org/abs/2104.11227) as backbone."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "b10ba34e",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using cache found in /Users/chizzy/.cache/torch/hub/facebookresearch_pytorchvideo_main\n",
      "Using cache found in /Users/chizzy/.cache/torch/hub/facebookresearch_pytorchvideo_main\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<table style=\"border-collapse: collapse;\"><tr><th style=\"text-align: center; font-size: 130%; border: none;\">mHR</th> <th style=\"text-align: center; font-size: 130%; border: none;\">mAP</th></tr> <tr><td style=\"text-align: center; vertical-align: center; border-right: solid 1px #D3D3D3; border-left: solid 1px #D3D3D3; \">0.785</td> <td style=\"text-align: center; vertical-align: center; border-right: solid 1px #D3D3D3; border-left: solid 1px #D3D3D3; \">0.8285317460317462</td></tr></table>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "collection = create_milvus_collection('mvit_base', 768)\n",
    "\n",
    "insert_pipe = (\n",
    "    pipe.input('csv_path')\n",
    "        .flat_map('csv_path', ('id', 'path', 'label'), read_csv)\n",
    "        .map('id', 'id', lambda x: int(x))\n",
    "        .map('path', 'frames', ops.video_decode.ffmpeg(sample_type='uniform_temporal_subsample', args={'num_samples': 32}))\n",
    "        .map('frames', ('labels', 'scores', 'features'), ops.action_classification.pytorchvideo(model_name='mvit_base_32x3', skip_preprocess=True))\n",
    "        .map('features', 'features', ops.towhee.np_normalize())\n",
    "        .map(('id', 'features'), 'insert_res', ops.ann_insert.milvus_client(host='127.0.0.1', port='19530', collection_name='mvit_base'))\n",
    "        .output()\n",
    ")\n",
    "\n",
    "insert_pipe('reverse_video_search.csv')\n",
    "\n",
    "collection.load()\n",
    "eval_pipe = (\n",
    "    pipe.input('path')\n",
    "        .flat_map('path', 'path', lambda x: glob.glob(x))\n",
    "        .map('path', 'frames', ops.video_decode.ffmpeg(sample_type='uniform_temporal_subsample', args={'num_samples': 32}))\n",
    "        .map('frames', ('labels', 'scores', 'features'), ops.action_classification.pytorchvideo(model_name='mvit_base_32x3', skip_preprocess=True))\n",
    "        .map('features', 'features', ops.towhee.np_normalize())\n",
    "        .map('features', 'result', ops.ann_search.milvus_client(host='127.0.0.1', port='19530', collection_name='mvit_base', limit=10))  \n",
    "        .map('result', 'predict', lambda x: [i[0] for i in x])\n",
    "        .map('path', 'ground_truth', ground_truth)\n",
    "        .window_all(('ground_truth', 'predict'), 'mHR', mean_hit_ratio)\n",
    "        .window_all(('ground_truth', 'predict'), 'mAP', mean_average_precision)\n",
    "        .output('mHR', 'mAP')\n",
    ")\n",
    "\n",
    "res = DataCollection(eval_pipe('./test/*/*.mp4'))\n",
    "res.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fe2e5ea9",
   "metadata": {},
   "source": [
    "Switching to MVIT model increases the mHR to 0.79 and mAP to 0.83, which are much better than X3D model. However, both insert and search time have increased. It's time for you to make trade-off between latency and accuracy. You're always encouraged to play around with this tutorial."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "607783a1",
   "metadata": {},
   "source": [
    "## Release a Showcase\n",
    "\n",
    "We've learnt how to build a reverse video search engine. Now it's time to add some interface and release a showcase."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "a78c9ba6",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using cache found in /Users/chizzy/.cache/torch/hub/facebookresearch_pytorchvideo_main\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Running on local URL:  http://127.0.0.1:7860\n",
      "Running on public URL: https://da02e32b-1569-4ad5.gradio.live\n",
      "\n",
      "This share link expires in 72 hours. For free permanent hosting and GPU upgrades (NEW!), check out Spaces: https://huggingface.co/spaces\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div><iframe src=\"https://da02e32b-1569-4ad5.gradio.live\" width=\"100%\" height=\"500\" allow=\"autoplay; camera; microphone; clipboard-read; clipboard-write;\" frameborder=\"0\" allowfullscreen></iframe></div>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": []
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import gradio\n",
    "\n",
    "video_search_pipe = (\n",
    "    pipe.input('path')\n",
    "        .flat_map('path', 'path', lambda x: glob.glob(x))\n",
    "        .map('path', 'frames', ops.video_decode.ffmpeg(sample_type='uniform_temporal_subsample', args={'num_samples': 32}))\n",
    "        .map('frames', ('labels', 'scores', 'features'), ops.action_classification.pytorchvideo(model_name='mvit_base_32x3', skip_preprocess=True))\n",
    "        .map('features', 'features', ops.towhee.np_normalize())\n",
    "        .map('features', 'result', ops.ann_search.milvus_client(host='127.0.0.1', port='19530', collection_name='mvit_base', limit=3)) \n",
    "        .map('result', 'predict', lambda x: [id_video[i[0]] for i in x])\n",
    "        .output('predict')\n",
    ")\n",
    "\n",
    "\n",
    "def video_search_function(video):\n",
    "    return video_search_pipe(video).to_list()[0][0]\n",
    "\n",
    "interface = gradio.Interface(video_search_function, \n",
    "                             inputs=gradio.Video(source='upload'),\n",
    "                             outputs=[gradio.Video(format='mp4') for _ in range(3)]\n",
    "                            )\n",
    "\n",
    "interface.launch(inline=True, share=True)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.12"
  },
  "vscode": {
   "interpreter": {
    "hash": "f7dd10cdbe9a9c71f7e71741efd428241b5f9fa0fecdd29ae07a5706cd5ff8a2"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
